{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep Reinforcement Learning in Action\n",
    "### by Alex Zai and Brandon Brown\n",
    "\n",
    "#### Chapter 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 5.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n",
      " 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47\n",
      " 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63]\n"
     ]
    }
   ],
   "source": [
    "import multiprocessing as mp\n",
    "from multiprocess import queues\n",
    "\n",
    "import numpy as np\n",
    "def square(x):\n",
    "    \"\"\"Return the element-wise square of ``x`` as computed by NumPy.\"\"\"\n",
    "    squared = np.square(x)\n",
    "    return squared\n",
    "# Sample input used by the multiprocessing examples: the integers 0 through 63.\n",
    "x = np.arange(64)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ask the multiprocessing module how many CPUs it reports for this machine.\n",
    "mp.cpu_count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == '__main__': # added this line for process safety\n",
    "    # Split the 64 inputs into 8 chunks of 8 and square each chunk in its own\n",
    "    # pool worker. The context manager shuts the pool down afterwards so the\n",
    "    # worker processes are not left running when the cell is re-executed.\n",
    "    with mp.Pool(8) as pool:\n",
    "        squared = pool.map(square, [x[8*i:8*i+8] for i in range(8)])\n",
    "    # A bare expression inside an `if` body is not displayed, so print it.\n",
    "    print(squared)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 5.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "In process 0\n",
      "In process 1\n",
      "In process 2\n",
      "In process 3\n",
      "In process 4\n",
      "In process 5\n",
      "In process 6\n",
      "In process 7\n"
     ]
    }
   ],
   "source": [
    "def square(i, x, queue):\n",
    "    \"\"\"Square one chunk in a child process and send the result to the parent.\n",
    "\n",
    "    Args:\n",
    "        i: Index of the process; only used in the progress message.\n",
    "        x: Array-like chunk passed to ``np.square``.\n",
    "        queue: Queue on which the squared chunk is put.\n",
    "    \"\"\"\n",
    "    print(\"In process {}\".format(i,))\n",
    "    # The put belongs inside the worker: at module level it runs once in the\n",
    "    # parent on the whole array and the child processes contribute nothing.\n",
    "    queue.put(np.square(x))\n",
    "\n",
    "queue = mp.Queue()\n",
    "processes = []\n",
    "if __name__ == '__main__': #adding this for process safety\n",
    "    x = np.arange(64)\n",
    "    for i in range(8):\n",
    "        start_index = 8*i\n",
    "        proc = mp.Process(target=square,args=(i,x[start_index:start_index+8],\n",
    "                         queue))\n",
    "        proc.start()\n",
    "        processes.append(proc)\n",
    "\n",
    "    # Collect exactly one result per process *before* joining: a child that\n",
    "    # has put data on a queue may not exit until that data is consumed, and\n",
    "    # queue.empty() is not a reliable completion signal. The timeout turns a\n",
    "    # crashed worker into an error instead of a hang. Results arrive in\n",
    "    # completion order, which need not match the process index.\n",
    "    results = [queue.get(timeout=10) for _ in processes]\n",
    "\n",
    "    for proc in processes:\n",
    "        proc.join()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([   0,    1,    4,    9,   16,   25,   36,   49,   64,   81,  100,\n",
       "         121,  144,  169,  196,  225,  256,  289,  324,  361,  400,  441,\n",
       "         484,  529,  576,  625,  676,  729,  784,  841,  900,  961, 1024,\n",
       "        1089, 1156, 1225, 1296, 1369, 1444, 1521, 1600, 1681, 1764, 1849,\n",
       "        1936, 2025, 2116, 2209, 2304, 2401, 2500, 2601, 2704, 2809, 2916,\n",
       "        3025, 3136, 3249, 3364, 3481, 3600, 3721, 3844, 3969])]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 5.3: Pseudocode (not shown)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 5.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "import numpy as np\n",
    "from torch.nn import functional as F\n",
    "import gym\n",
    "import torch.multiprocessing as mp\n",
    "\n",
    "class ActorCritic(nn.Module):\n",
    "    \"\"\"Actor-critic network with a shared trunk and two output heads.\n",
    "\n",
    "    The trunk maps 4 input features through two ReLU layers. The actor head\n",
    "    produces log-probabilities over 2 actions; the critic head produces a\n",
    "    single tanh-bounded value.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super(ActorCritic, self).__init__()\n",
    "        self.l1 = nn.Linear(4,25)\n",
    "        self.l2 = nn.Linear(25,50)\n",
    "        self.actor_lin1 = nn.Linear(50,2)\n",
    "        self.l3 = nn.Linear(50,25)\n",
    "        self.critic_lin1 = nn.Linear(25,1)\n",
    "    def forward(self,x):\n",
    "        \"\"\"Return ``(actor, critic)`` for a state or a batch of states.\n",
    "\n",
    "        Args:\n",
    "            x: Tensor whose last dimension holds the 4 state features.\n",
    "\n",
    "        Returns:\n",
    "            actor: Log-probabilities over the 2 actions (last dimension).\n",
    "            critic: Value estimate in (-1, 1), last dimension of size 1.\n",
    "        \"\"\"\n",
    "        # Work along the last (feature) dimension: identical to dim=0 for a\n",
    "        # single 1-D state, and also correct for a batch of shape (N, 4).\n",
    "        x = F.normalize(x,dim=-1)\n",
    "        y = F.relu(self.l1(x))\n",
    "        y = F.relu(self.l2(y))\n",
    "        actor = F.log_softmax(self.actor_lin1(y),dim=-1)\n",
    "        # detach() keeps the critic's gradients out of the shared trunk.\n",
    "        c = F.relu(self.l3(y.detach()))\n",
    "        critic = torch.tanh(self.critic_lin1(c))\n",
    "        return actor, critic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 5.5 \n",
    "##### NOTE 1: This cell will not run on its own; run Listings 5.6-5.8 first, then come back and run this cell.\n",
    "##### NOTE 2: This will not record losses for plotting. If you want to record losses, you'll need to create a multiprocessing shared array and modify the `worker` function to write each loss to it. See <https://docs.python.org/3/library/multiprocessing.html>. Alternatively, you could use process locks to safely write to a file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6998 0\n"
     ]
    }
   ],
   "source": [
    "# Build the shared model and launch one training process per worker.\n",
    "MasterNode = ActorCritic()\n",
    "# share_memory() lets every worker process update the same parameter tensors.\n",
    "MasterNode.share_memory()\n",
    "processes = []\n",
    "params = {\n",
    "    'epochs':1000,\n",
    "    'n_workers':7,\n",
    "}\n",
    "# Shared integer counter that is passed to every worker.\n",
    "counter = mp.Value('i',0)\n",
    "if __name__ == '__main__': #adding this for process safety\n",
    "    for i in range(params['n_workers']):\n",
    "        p = mp.Process(target=worker, args=(i,MasterNode,counter,params))\n",
    "        p.start()\n",
    "        processes.append(p)\n",
    "    # join() already waits for each worker to exit, so no terminate() is needed.\n",
    "    for p in processes:\n",
    "        p.join()\n",
    "\n",
    "# Report every worker's exit code rather than indexing a single process, which\n",
    "# fails when fewer than two workers are configured and hides other failures.\n",
    "print(counter.value,[p.exitcode for p in processes])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 5.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def worker(t, worker_model, counter, params):\n",
    "    \"\"\"Run the training loop for one worker process.\n",
    "\n",
    "    Args:\n",
    "        t: Worker index (not used inside the function).\n",
    "        worker_model: Model whose parameters are optimised with Adam.\n",
    "        counter: Shared ``Value`` incremented once per episode; it must have\n",
    "            been created with its default lock so ``get_lock()`` is available.\n",
    "        params: Dict providing the number of episodes under ``'epochs'``.\n",
    "    \"\"\"\n",
    "    worker_env = gym.make(\"CartPole-v1\")\n",
    "    worker_env.reset()\n",
    "    worker_opt = optim.Adam(lr=1e-4,params=worker_model.parameters())\n",
    "    for i in range(params['epochs']):\n",
    "        worker_opt.zero_grad()\n",
    "        values, logprobs, rewards = run_episode(worker_env,worker_model)\n",
    "        update_params(worker_opt,values,logprobs,rewards)\n",
    "        # `value = value + 1` is a read followed by a write; without the lock\n",
    "        # concurrent workers overwrite each other's increments and the final\n",
    "        # count comes out short.\n",
    "        with counter.get_lock():\n",
    "            counter.value += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 5.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_episode(worker_env, worker_model):\n",
    "    \"\"\"Play one full episode and record what is needed for the update.\n",
    "\n",
    "    Args:\n",
    "        worker_env: Environment exposing ``env.state``, ``step`` and ``reset``.\n",
    "        worker_model: Callable returning ``(policy, value)`` for a state tensor,\n",
    "            where ``policy`` holds log-probabilities over the actions.\n",
    "\n",
    "    Returns:\n",
    "        ``(values, logprobs, rewards)``: three lists with one entry per step.\n",
    "        The reward is 1.0 for every non-terminal step and -10 for the last one.\n",
    "    \"\"\"\n",
    "    # np.array() matches how Listing 5.9 and the test cell read the state and\n",
    "    # also accepts a state stored as a plain sequence instead of an ndarray.\n",
    "    state = torch.from_numpy(np.array(worker_env.env.state)).float()\n",
    "    values, logprobs, rewards = [],[],[]\n",
    "    done = False\n",
    "    while not done:\n",
    "        policy, value = worker_model(state)\n",
    "        values.append(value)\n",
    "        logits = policy.view(-1)\n",
    "        action_dist = torch.distributions.Categorical(logits=logits)\n",
    "        action = action_dist.sample()\n",
    "        # Log-probability of the action that was actually taken.\n",
    "        logprobs.append(logits[action])\n",
    "        state_, _, done, info = worker_env.step(action.detach().numpy())\n",
    "        state = torch.from_numpy(state_).float()\n",
    "        if done:\n",
    "            reward = -10\n",
    "            # Reset here so the next call starts from a fresh episode.\n",
    "            worker_env.reset()\n",
    "        else:\n",
    "            reward = 1.0\n",
    "        rewards.append(reward)\n",
    "    return values, logprobs, rewards"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 5.8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_params(worker_opt,values,logprobs,rewards,clc=0.1,gamma=0.95):\n",
    "    \"\"\"Compute the actor and critic losses for one rollout and step the optimizer.\n",
    "\n",
    "    Args:\n",
    "        worker_opt: Optimizer whose ``step`` is called after backpropagation.\n",
    "        values: List of value tensors, one per step.\n",
    "        logprobs: List of log-probability tensors, one per step.\n",
    "        rewards: List of per-step rewards.\n",
    "        clc: Weight of the critic loss in the combined loss.\n",
    "        gamma: Discount factor for the returns.\n",
    "\n",
    "    Returns:\n",
    "        ``(actor_loss, critic_loss, n_steps)`` with per-step loss tensors.\n",
    "    \"\"\"\n",
    "    # Everything is reversed so returns can be accumulated from the last step.\n",
    "    reverse = (0,)\n",
    "    rewards_rev = torch.Tensor(rewards).flip(dims=reverse).view(-1)\n",
    "    logprobs_rev = torch.stack(logprobs).flip(dims=reverse).view(-1)\n",
    "    values_rev = torch.stack(values).flip(dims=reverse).view(-1)\n",
    "\n",
    "    running = torch.Tensor([0])\n",
    "    discounted = []\n",
    "    for step_reward in rewards_rev:\n",
    "        running = step_reward + gamma * running\n",
    "        discounted.append(running)\n",
    "    returns = F.normalize(torch.stack(discounted).view(-1),dim=0)\n",
    "\n",
    "    # The advantage uses detached values so the actor loss does not train the critic.\n",
    "    advantage = returns - values_rev.detach()\n",
    "    actor_loss = -1*logprobs_rev * advantage\n",
    "    critic_loss = torch.pow(values_rev - returns,2)\n",
    "    total_loss = actor_loss.sum() + clc*critic_loss.sum()\n",
    "    total_loss.backward()\n",
    "    worker_opt.step()\n",
    "    return actor_loss, critic_loss, len(rewards_rev)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Supplement\n",
    "##### Test the trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let the trained model act in a fresh environment for 100 steps and render it.\n",
    "env = gym.make(\"CartPole-v1\")\n",
    "env.reset()\n",
    "\n",
    "# Inference only: no_grad() avoids recording an autograd graph at every step.\n",
    "with torch.no_grad():\n",
    "    for i in range(100):\n",
    "        # The state is read from the environment at the top of each iteration,\n",
    "        # so it does not need to be recomputed after the step.\n",
    "        state_ = np.array(env.env.state)\n",
    "        state = torch.from_numpy(state_).float()\n",
    "        logits,value = MasterNode(state)\n",
    "        action_dist = torch.distributions.Categorical(logits=logits)\n",
    "        action = action_dist.sample()\n",
    "        state2, reward, done, info = env.step(action.detach().numpy())\n",
    "        if done:\n",
    "            print(\"Lost\")\n",
    "            env.reset()\n",
    "        env.render()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Close the CartPole environment that was created for the test run.\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### N-step actor-critic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 5.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_episode(worker_env, worker_model, N_steps=10):\n",
    "    \"\"\"Roll the environment forward for at most ``N_steps`` steps.\n",
    "\n",
    "    This redefines ``run_episode`` from Listing 5.7 and returns a fourth\n",
    "    value, so a caller that unpacks three values has to be adapted.\n",
    "\n",
    "    Args:\n",
    "        worker_env: Environment exposing ``env.state``, ``step`` and ``reset``.\n",
    "        worker_model: Callable returning ``(policy, value)`` for a state tensor,\n",
    "            where ``policy`` holds log-probabilities over the actions.\n",
    "        N_steps: Maximum number of steps taken before returning.\n",
    "\n",
    "    Returns:\n",
    "        ``(values, logprobs, rewards, G)``. ``G`` is the bootstrap return: the\n",
    "        detached value from the last step when the rollout was cut off by\n",
    "        ``N_steps``, and zero when the episode terminated.\n",
    "    \"\"\"\n",
    "    raw_state = np.array(worker_env.env.state)\n",
    "    state = torch.from_numpy(raw_state).float()\n",
    "    values, logprobs, rewards = [],[],[]\n",
    "    done = False\n",
    "    j=0\n",
    "    G=torch.Tensor([0])\n",
    "    while (j < N_steps and not done):\n",
    "        j+=1\n",
    "        policy, value = worker_model(state)\n",
    "        values.append(value)\n",
    "        logits = policy.view(-1)\n",
    "        action_dist = torch.distributions.Categorical(logits=logits)\n",
    "        action = action_dist.sample()\n",
    "        # Log-probability of the action that was actually taken.\n",
    "        logprobs.append(logits[action])\n",
    "        state_, _, done, info = worker_env.step(action.detach().numpy())\n",
    "        state = torch.from_numpy(state_).float()\n",
    "        if done:\n",
    "            reward = -10\n",
    "            worker_env.reset()\n",
    "            # Terminal state: there is nothing to bootstrap from, so drop the\n",
    "            # value estimate carried over from the previous step.\n",
    "            G = torch.Tensor([0])\n",
    "        else:\n",
    "            reward = 1.0\n",
    "            G = value.detach()\n",
    "        rewards.append(reward)\n",
    "    return values, logprobs, rewards, G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
